#!/usr/bin/env perl

# Copyright 2014 The Souper Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

use warnings;
use strict;
use Redis;
use Getopt::Long;
use File::Temp;
use Time::HiRes;

# Paths of the external programs this script shells out to.  The @NAME@
# tokens are build-time placeholders (presumably substituted by CMake's
# configure_file -- TODO confirm), so this file is a template and is not
# runnable until they have been replaced.
my $llvmas = "@LLVM_BINDIR@/llvm-as";            # used by the -triage pipeline
my $llvmopt = "@LLVM_BINDIR@/opt";               # run with -O3 by -triage
my $llvmdis = "@LLVM_BINDIR@/llvm-dis";          # used by the -triage pipeline
my $check = "@CMAKE_BINARY_DIR@/souper-check";   # parsing, verification, inference
my $reducer = "@CMAKE_BINARY_DIR@/reduce";       # used by -reduce
my $triager = "@CMAKE_BINARY_DIR@/py_souper2llvm"; # output is piped into llvm-as by -triage

# Run $cmd through system() and return a shell-style exit status.
#
# Returns:
#   the command's exit code when it exited normally,
#   128 + N  when it was terminated by signal N,
#   255      when the command could not be started at all.
#
# The previous implementation returned "$? >> 8" unconditionally, which is 0
# for a child killed by a signal, so a crashing tool was reported as success.
sub runit ($) {
    my $cmd = shift;
    # Single-string form on purpose: callers rely on shell redirections.
    system($cmd);
    my $status = $?;
    return 255 if $status == -1;
    my $signal = $status & 127;
    return 128 + $signal if $signal;
    return $status >> 8;
}

# Print the option summary to standard output and terminate the script with
# exit status -1.  Called when option parsing fails or -sort has an
# unsupported value.
sub usage() {
    print <<'END';
Options:
  -nonNeg       Infer nonNegative DFA
  -neg          Infer Neg DFA
  -knownbits    Infer known bits
  -powerTwo     Infer power of two dfa
  -nonZero      Infer non Zero dfa
  -signBits     Infer number of sign bits dfa
  -demandedbits Infer demanded bits
  -merge        Merge optimizations that differ only in bitwidths and constants
  -noopts       Dump not-optimizations instead of optimizations
  -parse        Ensure that each LHS in the cache parses as Souper
  -redis-port=N Connect to the Redis server on localhost port N (default 6379)
  -sort=size|sprofile|dprofile|combined
                Sort optimizations by increasing size (default), decreasing
                static profile count, decreasing dynamic profile count, or
                combined static and dynamic profile rank
  -raw          Dump all keys and values and exit, ignoring all other options
  -reduce       Attempt to reduce the size of each optimization
  -triage       Attempt to avoid reporting optimizations that LLVM can do
  -weaken       Weaken dataflow facts
  -verbose
  -verify       Verify each optimization
END
    exit -1;
}

# Command-line configuration.  Valued options carry their defaults; every
# boolean switch starts out off and is turned on by GetOptions below.
my $REDISPORT = 6379;
my $SORT = "size";
my ($DEMANDEDBITS, $KNOWN, $MERGE, $NEG, $NONNEG, $NONZERO, $NOOPTS, $PARSE,
    $POWERTWO, $RAW, $REDUCE, $SIGNBITS, $TRIAGE, $VERBOSE, $VERIFY, $WEAKEN)
    = (0) x 16;

# Any unrecognized option prints the help text and exits.
GetOptions(
    "demandedbits" => \$DEMANDEDBITS,
    "knownbits"    => \$KNOWN,
    "merge"        => \$MERGE,
    "neg"          => \$NEG,
    "nonNeg"       => \$NONNEG,
    "nonZero"      => \$NONZERO,
    "noopts"       => \$NOOPTS,
    "parse"        => \$PARSE,
    "powerTwo"     => \$POWERTWO,
    "raw"          => \$RAW,
    "redis-port=i" => \$REDISPORT,
    "reduce"       => \$REDUCE,
    "signBits"     => \$SIGNBITS,
    "sort=s"       => \$SORT,
    "triage"       => \$TRIAGE,
    "verbose"      => \$VERBOSE,
    "verify"       => \$VERIFY,
    "weaken"       => \$WEAKEN,
    ) or usage();

# Reject sort orders the report code does not know about.
my %valid_sort = map { $_ => 1 } ("size", "sprofile", "dprofile", "combined");
usage() unless $valid_sort{$SORT};

# Global bookkeeping shared by all passes below.
my $noopt_count=0;      # number of cache entries whose "result" field is empty
my %values;             # NOTE(review): not referenced elsewhere in the visible source
my %sprofiles;          # optimization text -> total static profile count
my %dprofiles;          # optimization text -> total dynamic profile count
my %sprofile_locs;      # optimization text -> { location => static count }
my %dprofile_locs;      # optimization text -> { location => dynamic count }
my %toprint;            # set of optimization texts selected for the report

# Connect to the Redis server on localhost and fetch every key up front.
my $r = Redis->new(server => "localhost:" . $REDISPORT);
$r->ping || die "no server?";
my @all_keys = $r->keys('*');

# Give every key a zero profile so later numeric comparisons on these
# hashes see a defined value.
foreach my $opt (@all_keys) {
    $sprofiles{$opt} = 0;
    $dprofiles{$opt} = 0;
}

# -raw: print every key with all of its hash fields, both in sorted order,
# then exit without running any other pass.
if ($RAW) {
    foreach my $opt (sort @all_keys) {
        my %h = $r->hgetall($opt);
        print "<$opt>\n";
        foreach my $kk (sort keys %h) {
            print "  <$kk> <$h{$kk}>\n";
        }
        print "------------------------------------------------------\n\n";
    }
    exit 0;
}

print "; Inspecting ".scalar(@all_keys)." Redis values\n";

# Feed the LHS text $opt to souper-check and die if it is not acceptable.
#
# With -verify the tool is run with --infer-rhs and its output (minus lines
# starting with ';') must equal $RHS exactly.  Otherwise it is run with
# --parse-lhs-only and some output line must contain "success".
#
# Dies on mismatch, on parse failure, or when the tool cannot be started.
sub parse ($$) {
    my ($opt, $RHS) = @_;
    my ($fh, $tmpfn) = File::Temp::tempfile();
    print $fh $opt;
    # Close (and thereby flush) before the child reads the file.
    close $fh or die "cannot write temporary file $tmpfn: $!";

    my $arg = $VERIFY ? "--infer-rhs" : "--parse-lhs-only";

    # Lexical handle and checked open instead of the shared bareword INF.
    my $pipe;
    unless (open $pipe, '-|', "${check} $arg < $tmpfn 2>/dev/null") {
        my $err = $!;
        unlink $tmpfn;
        die "cannot run ${check}: $err";
    }
    my $output = "";
    my $success = 0;
    while (my $line = <$pipe>) {
        $success = 1 if ($line =~ /success/);
        next if ($line =~ /^;/);
        $output .= $line;
    }
    # The tool's exit status is deliberately not inspected, as before.
    close $pipe;
    unlink $tmpfn;

    if ($VERIFY) {
        die "expected '$RHS' but souper-check returned '$output'" unless
            ($output eq $RHS);
    } else {
        die "'$opt' does not parse as Souper" unless $success;
    }
}

# Shared implementation of the dataflow-inference printers below: write $opt
# to a temporary file, run souper-check with the single flag $arg reading
# that file on stdin, and copy the tool's standard output to ours.  The
# tool's stderr is discarded and its exit status is ignored.
# Dies if the temporary file cannot be written or the tool cannot be started.
sub _print_inference {
    my ($opt, $arg) = @_;
    my ($fh, $tmpfn) = File::Temp::tempfile();
    print $fh $opt;
    # Close (and thereby flush) before the child reads the file.
    close $fh or die "cannot write temporary file $tmpfn: $!";

    my $pipe;
    unless (open $pipe, '-|', "${check} $arg < $tmpfn 2>/dev/null") {
        my $err = $!;
        unlink $tmpfn;
        die "cannot run ${check}: $err";
    }
    while (my $line = <$pipe>) {
        print $line;
    }
    close $pipe;
    unlink $tmpfn;
    return;
}

# Print souper-check's -infer-known-bits output for $opt.
sub known ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-known-bits");
    return;
}

# Print souper-check's -infer-non-neg output for $opt.
sub nonNeg ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-non-neg");
    return;
}

# Print souper-check's -infer-neg output for $opt.
sub neg ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-neg");
    return;
}

# Print souper-check's -infer-power-two output for $opt.
sub powerTwo ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-power-two");
    return;
}

# Print souper-check's -infer-non-zero output for $opt.
sub nonZero ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-non-zero");
    return;
}

# Print souper-check's -infer-sign-bits output for $opt.
sub signBits ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-sign-bits");
    return;
}

# Print souper-check's -infer-demanded-bits output for $opt.
sub demandedbits ($) {
    my ($opt) = @_;
    _print_inference($opt, "-infer-demanded-bits");
    return;
}

# Run souper-check (no extra flags) on the text $s.
# Returns 1 if the tool's standard output contains "LGTM", 0 otherwise.
# Dies if the temporary file cannot be written or the tool cannot be started.
sub checkstr ($) {
    my ($s) = @_;
    my ($fh, $tmpfn) = File::Temp::tempfile();
    print $fh $s;
    # Close (and thereby flush) before the child reads the file.
    close $fh or die "cannot write temporary file $tmpfn: $!";

    # Lexical handle and checked open instead of the shared bareword INF.
    my $pipe;
    unless (open $pipe, '-|', "${check} < $tmpfn 2>/dev/null") {
        my $err = $!;
        unlink $tmpfn;
        die "cannot run ${check}: $err";
    }
    my $output = "";
    while (my $line = <$pipe>) {
        chomp $line;
        $output .= $line;
    }
    close $pipe;
    unlink $tmpfn;
    return ($output =~ /LGTM/) ? 1 : 0;
}

# Accumulate the per-location static counts in %$href into the entry for
# $opt in %sprofile_locs.  The entry is guaranteed to exist afterwards,
# even when %$href is empty.
sub add_sprofile($$) {
    my ($opt, $href) = @_;
    my %counts = %{$href};
    $sprofile_locs{$opt} = {} unless $sprofile_locs{$opt};
    while (my ($loc, $n) = each %counts) {
        $sprofile_locs{$opt}{$loc} += $n;
    }
}

# Accumulate the per-location dynamic counts in %$href into the entry for
# $opt in %dprofile_locs.  The entry is guaranteed to exist afterwards,
# even when %$href is empty.
sub add_dprofile($$) {
    my ($opt, $href) = @_;
    my %counts = %{$href};
    $dprofile_locs{$opt} = {} unless $dprofile_locs{$opt};
    while (my ($loc, $n) = each %counts) {
        $dprofile_locs{$opt}{$loc} += $n;
    }
}

# Load every cache entry: count tagged/untagged entries, optionally check
# that the LHS parses, sum the per-location profile fields, and select the
# entries to report.
my $xcnt = 0;       # only used by the disabled debugging limit below
my $tagged = 0;     # entries that have a "cache-infer-tag" field
my $untagged = 0;   # entries that do not
foreach my $opt (@all_keys) {
    # last if $xcnt++ > 2500;
    my %h = $r->hgetall($opt);
    my $result = $h{"result"};
    if (defined $h{"cache-infer-tag"}) {
	$tagged++;
    } else {
	$untagged++;
    }
    # Entries without a "result" field are counted above but otherwise skipped.
    if (defined($result)) {
        parse($opt, $result) if $PARSE;
    } else {
        next;
    }
    my $sprofile = 0;
    my $dprofile = 0;
    my %sprofile_loc;
    my %dprofile_loc;
    # Fields named "sprofile <loc>" / "dprofile <loc>" hold a count for that
    # location; accumulate both the total and the per-location breakdown.
    foreach my $kk (keys %h) {
        if ($kk =~ /^sprofile (.*)$/) {
            my $count = $h{$kk};
            $sprofile += $count;
            $sprofile_loc{$1} += $count;
        }
        if ($kk =~ /^dprofile (.*)$/) {
            my $count = $h{$kk};
            $dprofile += $count;
            $dprofile_loc{$1} += $count;
        }
    }
    # From here on the optimization is identified by LHS text followed by
    # its result.  $opt aliases the element of @all_keys, so this append
    # also rewrites that element in place; the "combined" ranking later
    # iterates @all_keys and therefore sees the appended form.
    $opt .= $result;
    # An empty result marks a not-optimization; -noopts selects those
    # instead of the real optimizations.
    if ($result eq "") {
        $noopt_count++;
        $toprint{$opt} = 1 if $NOOPTS;
    } else {
        $toprint{$opt} = 1 if !$NOOPTS;
    }
    add_sprofile($opt, \%sprofile_loc);
    add_dprofile($opt, \%dprofile_loc);
    $sprofiles{$opt} = $sprofile;
    $dprofiles{$opt} = $dprofile;
}

print "; Discarding ${noopt_count} not-optimizations leaving ".
    scalar(keys %toprint)." optimizations\n";

# Replace the optimization $old by $new in the report set: the profile
# totals and per-location counts of $old are added to those of $new, $new
# is selected for printing and $old is deselected.
#
# Dies if $new is empty.  Replacing an optimization by itself is a no-op;
# previously that case doubled its profile counts and then deleted it from
# %toprint, silently dropping it from the report.
sub replace($$) {
    my ($old, $new) = @_;
    die "replace: empty replacement for '$old'" if $new eq "";
    return if $new eq $old;
    $sprofiles{$new} += $sprofiles{$old};
    $dprofiles{$new} += $dprofiles{$old};
    # Tolerate a missing per-location table instead of dying on the deref.
    add_sprofile($new, $sprofile_locs{$old} || {});
    add_dprofile($new, $dprofile_locs{$old} || {});
    $toprint{$new} = 1;
    delete $toprint{$old};
    return;
}

# Forget the optimization $opt entirely: drop it from the report set and
# from all four profile tables.
sub remove($) {
    my ($victim) = @_;
    foreach my $table (\%toprint, \%sprofiles, \%dprofiles,
                       \%sprofile_locs, \%dprofile_locs) {
        delete $table->{$victim};
    }
}

# Progress-meter state: items done, last percentage printed, items expected.
my $status_cnt;
my $status_opct;
my $status_total;

# Start a new progress run that expects $t calls to status().
sub reset_status($) {
    my ($expected) = @_;
    ($status_total, $status_opct, $status_cnt) = ($expected, 0, 0);
}

# Record one finished item: print a dot, and print the whole-number
# percentage followed by a newline whenever it has increased.
sub status() {
    print ".";
    $status_cnt += 1;
    my $percent = int(100.0 * $status_cnt / $status_total);
    return unless $percent > $status_opct;
    $status_opct = $percent;
    print "$percent %\n";
}

# Return the sum of static and dynamic profile counts over every
# optimization currently selected for printing.
sub total_profile_count() {
    my $sum = 0;
    $sum += $sprofiles{$_} + $dprofiles{$_} for keys %toprint;
    return $sum;
}

# -reduce: run the external reducer on every selected optimization and
# substitute its output for the original text.
#   * reducer exits non-zero -> report it and drop the optimization
#   * reducer output is empty -> report it and keep the original
#   * reducer output unchanged -> keep the original as is
# Both temporary files are removed on every path (they used to leak when
# the reducer failed), and progress advances for every item.
if ($REDUCE) {
    print "; Reducing\n";
    # Only needed by the disabled sanity check at the end of this block.
    my $tprofile = total_profile_count();

    my @keys = keys %toprint;
    reset_status(scalar(@keys)) if $VERBOSE;
    foreach my $opt (@keys) {
        my ($fh1, $fn1) = File::Temp::tempfile();
        my ($fh2, $fn2) = File::Temp::tempfile();
        print $fh1 $opt;
        close $fh1;
        close $fh2;

        my $rc = runit("$reducer < $fn1 > $fn2");
        my $new = "";
        if ($rc == 0) {
            open my $in, '<', $fn2 or die "cannot read $fn2: $!";
            $new = join "", <$in>;
            close $in;
        }
        unlink $fn1;
        unlink $fn2;

        if ($rc != 0) {
            print "cannot reduce '$opt'\n";
            remove($opt);
        } elsif ($new eq "") {
            print "cannot reduce '$opt'\n";
        } elsif ($new ne $opt) {
            # Never replace an optimization by itself: that would double
            # its profile counts and drop it from the report.
            replace($opt, $new);
        }
        status() if $VERBOSE;
    }

    print "; After reducing there are ".scalar(keys %toprint)." optimizations\n";
    # die unless ($tprofile == total_profile_count());
}

# Open $cmd with two-argument open (callers pass a command ending in "|",
# i.e. a pipe to read from) and return a reference to the list of lines it
# produced.  Dies if the open fails.
sub grab_output($) {
    my ($cmd) = @_;
    open(my $pipe, $cmd) or die;
    my @lines = <$pipe>;
    close $pipe;
    return \@lines;
}

# Count the lines in @$lref that lie between a line containing "foo0:" and
# the next line containing "ret " (both marker lines excluded).
sub count_insns($) {
    my ($lref) = @_;
    my ($inside, $total) = (0, 0);
    for my $text (@{$lref}) {
        if ($text =~ /foo0:/) {
            $inside = 1;
        } elsif ($text =~ /ret /) {
            $inside = 0;
        } elsif ($inside) {
            $total++;
        }
    }
    return $total;
}

# -triage: convert each selected optimization to LLVM IR with the triager,
# once as is and once through "opt -O3", and drop the optimization when the
# -O3 version has fewer counted instructions (LLVM already finds it).
# The temporary input file is now removed after use; it used to leak once
# per optimization.
if ($TRIAGE) {
    print "; Triaging\n";

    my @keys = keys %toprint;
    reset_status(scalar(@keys)) if $VERBOSE;
    foreach my $opt (@keys) {
        my ($fh1, $fn1) = File::Temp::tempfile();
        print $fh1 $opt;
        close $fh1;
        my $plain = grab_output("$triager < $fn1 | $llvmas | $llvmdis |");
        my $optimized = grab_output("$triager < $fn1 | $llvmas | $llvmopt -O3 | $llvmdis |");
        unlink $fn1;
        #print "\n@{$plain}\n\n";
        #print "\n@{$optimized}\n\n";
        remove($opt) unless count_insns($plain) <= count_insns($optimized);
        #print "----------------------------\n";
        status() if $VERBOSE;
    }

    print "; After triaging there are ".scalar(keys %toprint)." optimizations\n";
}

# Instruction mnemonics carrying extra flags, mapped to the plain
# instruction they are weakened to.
my %weaken_insn = (
    "addnsw"    => "add",
    "addnuw"    => "add",
    "addnw"     => "add",
    "subnsw"    => "sub",
    "subnuw"    => "sub",
    "subnw"     => "sub",
    "mulnsw"    => "mul",
    "mulnuw"    => "mul",
    "mulnw"     => "mul",
    "udivexact" => "udiv",
    "sdivexact" => "sdiv",
    "shlnsw"    => "shl",
    "shlnuw"    => "shl",
    "shlnw"     => "shl",
    "lshrexact" => "lshr",
    "ashrexact" => "ashr",
    );

# Weaken the $n-th (0-based) weakenable item of $str, scanning left to right.
# Weakenable items are:
#   * each "0" or "1" character inside a parenthesized group, which is
#     replaced by "x";
#   * each occurrence, outside parentheses, of a key of %weaken_insn, which
#     is replaced by the corresponding plain instruction.
# Returns (new string, "bit") or (new string, "insn"), or (undef, "") when
# the string has fewer than $n+1 weakenable items.  Dies on nested or
# unbalanced parentheses.
sub replace_nth($$) {
    my ($text, $target) = @_;
    my $seen = 0;
    my $inside = 0;
    foreach my $idx (0 .. length($text) - 1) {
        my $ch = substr($text, $idx, 1);
        if ($ch eq "(") {
            die if $inside;
            $inside = 1;
            next;
        }
        if ($ch eq ")") {
            die unless $inside;
            $inside = 0;
            next;
        }
        if ($inside) {
            # Only literal bit characters are candidates inside a group.
            next unless ($ch eq "0" || $ch eq "1");
            if ($seen == $target) {
                substr($text, $idx, 1) = "x";
                return ($text, "bit");
            }
            $seen++;
            next;
        }
        # Outside parentheses: does a flagged instruction start here?
        my $tail = substr($text, $idx);
        foreach my $strong (keys %weaken_insn) {
            my $weak = $weaken_insn{$strong};
            if ($tail =~ s/^$strong/$weak/) {
                if ($seen == $target) {
                    return (substr($text, 0, $idx) . $tail, "insn");
                }
                $seen++;
            }
        }
    }
    return (undef, "");
}

# Counters of accepted (yes) and rejected (no) weakening attempts, split by
# the kind of item that replace_nth() changed.
my %weaken_yes = ("insn" => 0, "bit" => "0");
my %weaken_no = ("insn" => 0, "bit" => "0");

# -weaken: for every selected optimization, try to weaken each weakenable
# item in turn and keep a change only when souper-check still accepts the
# resulting text.
if ($WEAKEN) {
    print "; Weakening\n";
    my @keys = keys %toprint;
    foreach my $opt (@keys) {
	my $orig = $opt;
	# Index of the next candidate item to try in the current text.
	my $i = 0;
	while (1) {
	    (my $new, my $how) = replace_nth($opt, $i);
	    # undef means there is no $i-th candidate left.
	    last unless defined $new;
	    if (checkstr($new)) {
		# Accepted: the weakened item is no longer a candidate, so
		# the same index now designates the following one.
		$opt = $new;
		$weaken_yes{$how}++;
	    } else {
		# Rejected: leave this item alone and move on to the next.
		$weaken_no{$how}++;
		$i++;
	    }
	}
	die unless checkstr($opt);
	# Remove parenthesized groups that consist solely of "x" characters,
	# then make sure the result is still accepted.
	$opt =~ s/ \(x+\)//g;
	die unless checkstr($opt);
	replace($orig, $opt) if ($opt ne $orig);
    }
    # The second number printed on each line is the total number of
    # attempts (accepted plus rejected).
    my $yes = $weaken_yes{"bit"};
    my $no = $weaken_no{"bit"} + $weaken_yes{"bit"};
    print "; weakening bits worked $yes / $no times\n";
    $yes = $weaken_yes{"insn"};
    $no = $weaken_no{"insn"} + $weaken_yes{"insn"};
    print "; weakening insns worked $yes / $no times\n";
    print "; After weakening there are ".scalar(keys %toprint)." optimizations\n";
    # die unless ($tprofile == total_profile_count());
}

# -merge: collapse optimizations that differ only in bit widths and
# constants by stripping every ":i<N>" suffix and turning every number that
# follows a space into "C"; the profiles of all optimizations that map to
# the same text are accumulated by replace().
if ($MERGE) {
    print "; Merging\n";
    # Only needed by the disabled sanity check at the end of this block.
    my $tprofile = total_profile_count();

    my @keys = keys %toprint;
    foreach my $opt (@keys) {
        my $new = $opt;
        $new =~ s/:i[0-9]+//g;
        $new =~ s/ [0-9]+/ C/g;
        # Text without widths or constants is already in merged form.
        # Replacing it by itself would double its profile counts and drop
        # it from the report.
        next if $new eq $opt;
        replace($opt, $new);
    }
    print "; After merging there are ".scalar(keys %toprint)." optimizations\n";
    # die unless ($tprofile == total_profile_count());
}

# Rank tables for -sort=combined: position of each optimization when
# ordered by decreasing static / dynamic profile count (0 is the highest).
my %sprofile_rank;
my %dprofile_rank;

if ($SORT eq "combined") {
    # Rank exactly the optimizations that will be printed.  Ranking the raw
    # Redis keys instead left every text produced by -reduce, -weaken or
    # -merge without a rank, and compared undefined profile counts for
    # entries dropped by remove().
    my @ranked = keys %toprint;

    my $n = 0;
    foreach my $opt (sort { ($sprofiles{$b} || 0) <=> ($sprofiles{$a} || 0) } @ranked) {
        $sprofile_rank{$opt} = $n;
        $n++;
    }
    $n = 0;
    foreach my $opt (sort { ($dprofiles{$b} || 0) <=> ($dprofiles{$a} || 0) } @ranked) {
        $dprofile_rank{$opt} = $n;
        $n++;
    }
}

# Sort comparators for the final report; each compares the package
# variables $a and $b as required by sort().

# Shorter optimization text first.
sub bylen {
    return length($a) <=> length($b);
}

# Larger static profile count first.
sub bysprofile {
    return $sprofiles{$b} <=> $sprofiles{$a};
}

# Larger dynamic profile count first.
sub bydprofile {
    return $dprofiles{$b} <=> $dprofiles{$a};
}

# Smaller sum of static and dynamic rank first.
sub byrank {
    my $left  = $sprofile_rank{$a} + $dprofile_rank{$a};
    my $right = $sprofile_rank{$b} + $dprofile_rank{$b};
    return $left <=> $right;
}

# Pick the comparator for the requested order; size is the fallback.
my %comparator_for = (
    "sprofile" => \&bysprofile,
    "dprofile" => \&bydprofile,
    "combined" => \&byrank,
    );
my $byx = $comparator_for{$SORT} || \&bylen;

print "\n\n";

# Emit the report: each selected optimization in the chosen order, followed
# by any requested dataflow inferences and its static and dynamic profile
# totals with a per-location breakdown (largest count first).
foreach my $opt (sort $byx keys %toprint) {
    print "$opt\n";
    nonNeg($opt) if $NONNEG;
    neg($opt) if $NEG;
    known($opt) if $KNOWN;
    powerTwo($opt) if $POWERTWO;
    nonZero($opt) if $NONZERO;
    signBits($opt) if $SIGNBITS;
    demandedbits($opt) if $DEMANDEDBITS;
    print "\n";

    foreach my $section (
        [ "static",  "sprofile", \%sprofiles, \%sprofile_locs ],
        [ "dynamic", "dprofile", \%dprofiles, \%dprofile_locs ],
    ) {
        my ($label, $tag, $totals, $locs) = @{$section};
        print "; total $label profile = $totals->{$opt}\n";
        my %h = %{$locs->{$opt}};
        foreach my $k (sort { $h{$b} <=> $h{$a} } keys %h) {
            # Counts recorded without a location are not listed.
            next if ($k eq "");
            print "; $tag $h{$k} \"$k\"\n";
        }
    }
    print "------------------------------------------------------\n\n";
}

# Closing summary: total static profile weight of everything printed and
# the tagged/untagged entry counts gathered while loading the cache.
my $cnt = 0;
$cnt += $sprofiles{$_} for keys %toprint;
print "; overall total static profile weight = $cnt\n";
print "; $tagged were tagged by cache_infer, $untagged were not\n";
